{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据输入（三） 不定长序列数据\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.打包 tf.data.Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 最简单的例子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0]\n",
      " [1]\n",
      " [2]\n",
      " [3]\n",
      " [4]]\n"
     ]
    }
   ],
   "source": [
    "X = np.arange(20).reshape([10, 2])\n",
    "y = np.arange(10).reshape([10,1])\n",
    "print(y[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[文档](https://www.tensorflow.org/programmers_guide/datasets)中说：\n",
    "\n",
    ">要启动输入管道，您必须定义来源。例如，要通过内存中的某些张量构建 Dataset，您可以使用 tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()。或者，如果您的输入数据以推荐的 TFRecord 格式存储在磁盘上，那么您可以构建 tf.data.TFRecordDataset\n",
    "\n",
    "现在我们先来看看 from_tensor_slices() 的用法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<TensorSliceDataset shapes: ((2,), (1,)), types: (tf.int64, tf.int64)>\n"
     ]
    }
   ],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices((X, y))\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这样我们就把 X,y 构建成了一个 dataset，可以看到 X 和 y 的维度，但是并不能看到样本的数量。\n",
    "\n",
    "下一步，我们要取出 dataset 中的数据，需要创建一个迭代器。tf 提供了好多种创建迭代器的方式，我发现目前来说使用最简单的方法就基本能够满足要求了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7f5fbc995f28>\n"
     ]
    }
   ],
   "source": [
    "iterator = dataset.make_one_shot_iterator()\n",
    "print(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "count = 0 : y = [0]\n",
      "count = 1 : y = [1]\n",
      "count = 2 : y = [2]\n",
      "count = 3 : y = [3]\n",
      "count = 4 : y = [4]\n",
      "count = 5 : y = [5]\n",
      "count = 6 : y = [6]\n",
      "count = 7 : y = [7]\n",
      "count = 8 : y = [8]\n",
      "count = 9 : y = [9]\n",
      "\n",
      " End of sequence\n",
      "\t [[Node: IteratorGetNext_10 = IteratorGetNext[output_shapes=[[2], [1]], output_types=[DT_INT64, DT_INT64], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](OneShotIterator)]]\n",
      "\n",
      "Caused by op 'IteratorGetNext_10', defined at:\n",
      "  File \"/usr/lib/python3.5/runpy.py\", line 184, in _run_module_as_main\n",
      "    \"__main__\", mod_spec)\n",
      "  File \"/usr/lib/python3.5/runpy.py\", line 85, in _run_code\n",
      "    exec(code, run_globals)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py\", line 16, in <module>\n",
      "    app.launch_new_instance()\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py\", line 658, in launch_instance\n",
      "    app.start()\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel/kernelapp.py\", line 486, in start\n",
      "    self.io_loop.start()\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tornado/platform/asyncio.py\", line 127, in start\n",
      "    self.asyncio_loop.run_forever()\n",
      "  File \"/usr/lib/python3.5/asyncio/base_events.py\", line 345, in run_forever\n",
      "    self._run_once()\n",
      "  File \"/usr/lib/python3.5/asyncio/base_events.py\", line 1312, in _run_once\n",
      "    handle._run()\n",
      "  File \"/usr/lib/python3.5/asyncio/events.py\", line 125, in _run\n",
      "    self._callback(*self._args)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tornado/platform/asyncio.py\", line 117, in _handle_events\n",
      "    handler_func(fileobj, events)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py\", line 276, in null_wrapper\n",
      "    return fn(*args, **kwargs)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py\", line 450, in _handle_events\n",
      "    self._handle_recv()\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py\", line 480, in _handle_recv\n",
      "    self._run_callback(callback, msg)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py\", line 432, in _run_callback\n",
      "    callback(*args, **kwargs)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py\", line 276, in null_wrapper\n",
      "    return fn(*args, **kwargs)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py\", line 283, in dispatcher\n",
      "    return self.dispatch_shell(stream, msg)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py\", line 233, in dispatch_shell\n",
      "    handler(stream, idents, msg)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py\", line 399, in execute_request\n",
      "    user_expressions, allow_stdin)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel/ipkernel.py\", line 208, in do_execute\n",
      "    res = shell.run_cell(code, store_history=store_history, silent=silent)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/ipykernel/zmqshell.py\", line 537, in run_cell\n",
      "    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py\", line 2662, in run_cell\n",
      "    raw_cell, store_history, silent, shell_futures)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py\", line 2785, in _run_cell\n",
      "    interactivity=interactivity, compiler=compiler, result=result)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py\", line 2903, in run_ast_nodes\n",
      "    if self.run_code(code, result):\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py\", line 2963, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"<ipython-input-5-246ed8c2d123>\", line 8, in <module>\n",
      "    x_batch, y_batch = sess.run(iterator.get_next())\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/iterator_ops.py\", line 366, in get_next\n",
      "    name=name)), self._output_types,\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_dataset_ops.py\", line 1455, in iterator_get_next\n",
      "    output_shapes=output_shapes, name=name)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py\", line 787, in _apply_op_helper\n",
      "    op_def=op_def)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py\", line 3290, in create_op\n",
      "    op_def=op_def)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py\", line 1654, in __init__\n",
      "    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access\n",
      "\n",
      "OutOfRangeError (see above for traceback): End of sequence\n",
      "\t [[Node: IteratorGetNext_10 = IteratorGetNext[output_shapes=[[2], [1]], output_types=[DT_INT64, DT_INT64], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](OneShotIterator)]]\n",
      "\n",
      "final count = 10\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "count = 0\n",
    "try:\n",
    "    while True:\n",
    "        x_batch, y_batch = sess.run(iterator.get_next())\n",
    "        print('count = {} : y = {}'.format(count, y_batch))\n",
    "        count += 1\n",
    "except Exception as e:\n",
    "    print('\\n', e)\n",
    "    print('final count = {}'.format(count))\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从上面的结果可以看到，我们使用 `iterator = dataset.make_one_shot_iterator()` 的方式创建的迭代器可以每次取出一个样本， one_shot 表示只能迭代一次，迭代完了以后再取值就会报错了。\n",
    "\n",
    "### 2.2 相关功能\n",
    "\n",
    "在实际应用中，tf 提供了以下几个函数来实现我们的要求：下面摘自 https://zhuanlan.zhihu.com/p/30751039\n",
    "\n",
    "\n",
    "**map**\n",
    "\n",
    "map接收一个函数，Dataset中的每个元素都会被当作这个函数的输入，并将函数返回值作为新的Dataset，如我们可以对dataset中每个元素的值加1：\n",
    "```\n",
    "dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))\n",
    "dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0\n",
    "```\n",
    "\n",
    "**batch**\n",
    "\n",
    "batch就是将多个元素组合成batch，如下面的程序将dataset中的每个元素组成了大小为32的batch：\n",
    "```\n",
    "dataset = dataset.batch(32)\n",
    "```\n",
    "\n",
    "**shuffle**\n",
    "\n",
    "shuffle的功能为打乱dataset中的元素，它有一个参数buffersize，表示打乱时使用的buffer的大小：\n",
    "```\n",
    "dataset = dataset.shuffle(buffer_size=10000)\n",
    "```\n",
    "\n",
    "**repeat**\n",
    "\n",
    "repeat的功能就是将整个序列重复多次，主要用来处理机器学习中的epoch，假设原先的数据是一个epoch，使用repeat(5)就可以将之变成5个epoch：\n",
    "```\n",
    "dataset = dataset.repeat(5)\n",
    "```\n",
    "如果**直接调用repeat()的话，生成的序列就会无限重复下去，没有结束**，因此也不会抛出tf.errors.OutOfRangeError异常：\n",
    "```\n",
    "dataset = dataset.repeat()\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 1]\n",
      " [2 3]\n",
      " [4 5]\n",
      " [6 7]\n",
      " [8 9]]\n",
      "[[0]\n",
      " [1]\n",
      " [2]\n",
      " [3]\n",
      " [4]]\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "\n",
    "X = np.arange(20).reshape([10, 2])\n",
    "y = np.arange(10).reshape([10,1])\n",
    "print(X[:5])\n",
    "print(y[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面我们要实现每次取出 4 个值，并且把 x 中每个元素×2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?, 2), (?, 1)), types: (tf.int64, tf.int64)>\n",
      "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fad43f70e48>\n"
     ]
    }
   ],
   "source": [
    "def pre_func(X, y):\n",
    "    X = tf.multiply(X, 2)\n",
    "    return X, y\n",
    "\n",
    "batch_size = 6\n",
    "# 创建 dataset\n",
    "dataset = tf.data.Dataset.from_tensor_slices((X, y))\n",
    "\n",
    "# 对数据进行操作\n",
    "dataset = dataset.map(pre_func)\n",
    "dataset = dataset.repeat(3).batch(4)\n",
    "print(dataset)\n",
    "\n",
    "# 生成迭代器\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "print(iterator)\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# 迭代取值\n",
    "count = 0\n",
    "try:\n",
    "    while True:\n",
    "        X_batch, y_batch = sess.run(iterator.get_next())\n",
    "        print('count = {} : X = {}'.format(count, X_batch))\n",
    "        count += 1\n",
    "except Exception as e:\n",
    "    print('final count = {}'.format(count))\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面添加 shuffle 功能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BatchDataset shapes: ((?, 2), (?, 1)), types: (tf.int64, tf.int64)>\n",
      "<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fad40054b38>\n",
      "count = 0 : y = [1 9 6 7]\n",
      "count = 1 : y = [3 5 0 8]\n",
      "count = 2 : y = [2 4 3 6]\n",
      "count = 3 : y = [7 2 1 0]\n",
      "count = 4 : y = [5 8 9 4]\n",
      "count = 5 : y = [7 6 2 9]\n",
      "count = 6 : y = [0 3 8 1]\n",
      "count = 7 : y = [4 5 2 0]\n",
      "count = 8 : y = [6 9 4 5]\n",
      "count = 9 : y = [8 3 7 1]\n"
     ]
    }
   ],
   "source": [
    "def pre_func(X, y):\n",
    "    X = tf.multiply(X, 2)\n",
    "    return X, y\n",
    "\n",
    "batch_size = 6\n",
    "# 创建 dataset\n",
    "dataset = tf.data.Dataset.from_tensor_slices((X, y))\n",
    "\n",
    "# 对数据进行操作\n",
    "dataset = dataset.map(pre_func)\n",
    "dataset = dataset.shuffle(buffer_size=20)  # 设置 buffle_size 越大，shuffle 越均匀\n",
    "dataset = dataset.repeat().batch(4)\n",
    "print(dataset)\n",
    "\n",
    "# 生成迭代器\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "print(iterator)\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# 迭代取值\n",
    "count = 0\n",
    "for count in range(10):\n",
    "    X_batch, y_batch = sess.run(iterator.get_next())\n",
    "    print('count = {} : y = {}'.format(count, y_batch.reshape([-1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从上面的结果来看，对于能够一次性读入内存中的数据。dataset 具备了非常好的功能：\n",
    "- 使用 map 函数可以对数据进行操作\n",
    "- 使用 shuffle 能够很好地扰乱数据，如果 先 shuffle 再 repeat 的话，可以每次都无重复地迭代完一个 epoch 再到下一个epoch；如果是先 repeat 再 shuffle 的话可能每个batch中会出现相同的样本。\n",
    "- 使用 repeat 函数可以不断地取样本\n",
    "- 使用 batch 函数可以很方便的读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "\n",
    "X = [[1,2,3,4], [5,6],[7],[8,9,10]]\n",
    "y = [4,2,1,3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Argument must be a dense tensor: [[1, 2, 3, 4], [5, 6], [7], [8, 9, 10]] - got shape [4], but wanted [4, 4].",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-9335104fe59e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;31m# 创建 dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_tensor_slices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;31m# 对数据进行操作\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36mfrom_tensor_slices\u001b[0;34m(tensors)\u001b[0m\n\u001b[1;32m    220\u001b[0m       \u001b[0mDataset\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mA\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mDataset\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    221\u001b[0m     \"\"\"\n\u001b[0;32m--> 222\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mTensorSliceDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    224\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, tensors)\u001b[0m\n\u001b[1;32m   1015\u001b[0m           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(\n\u001b[1;32m   1016\u001b[0m               t, name=\"component_%d\" % i)\n\u001b[0;32m-> 1017\u001b[0;31m           \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1018\u001b[0m       ])\n\u001b[1;32m   1019\u001b[0m       \u001b[0mflat_tensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   1015\u001b[0m           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(\n\u001b[1;32m   1016\u001b[0m               t, name=\"component_%d\" % i)\n\u001b[0;32m-> 1017\u001b[0;31m           \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1018\u001b[0m       ])\n\u001b[1;32m   1019\u001b[0m       \u001b[0mflat_tensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36mconvert_to_tensor\u001b[0;34m(value, dtype, name, preferred_dtype)\u001b[0m\n\u001b[1;32m    948\u001b[0m       \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    949\u001b[0m       \u001b[0mpreferred_dtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpreferred_dtype\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 950\u001b[0;31m       as_ref=False)\n\u001b[0m\u001b[1;32m    951\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    952\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36minternal_convert_to_tensor\u001b[0;34m(value, dtype, name, as_ref, preferred_dtype, ctx)\u001b[0m\n\u001b[1;32m   1038\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1039\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1040\u001b[0;31m       \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconversion_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mas_ref\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mas_ref\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1041\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1042\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py\u001b[0m in \u001b[0;36m_constant_tensor_conversion_function\u001b[0;34m(v, dtype, name, as_ref)\u001b[0m\n\u001b[1;32m    233\u001b[0m                                          as_ref=False):\n\u001b[1;32m    234\u001b[0m   \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mas_ref\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 235\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mconstant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    236\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    237\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py\u001b[0m in \u001b[0;36mconstant\u001b[0;34m(value, dtype, shape, name, verify_shape)\u001b[0m\n\u001b[1;32m    212\u001b[0m   tensor_value.tensor.CopyFrom(\n\u001b[1;32m    213\u001b[0m       tensor_util.make_tensor_proto(\n\u001b[0;32m--> 214\u001b[0;31m           value, dtype=dtype, shape=shape, verify_shape=verify_shape))\n\u001b[0m\u001b[1;32m    215\u001b[0m   \u001b[0mdtype_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mattr_value_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAttrValue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtensor_value\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    216\u001b[0m   const_tensor = g.create_op(\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py\u001b[0m in \u001b[0;36mmake_tensor_proto\u001b[0;34m(values, dtype, shape, verify_shape)\u001b[0m\n\u001b[1;32m    440\u001b[0m                          \u001b[0;34m\"\"\" - got shape %s, but wanted %s.\"\"\"\u001b[0m \u001b[0;34m%\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    441\u001b[0m                          (values, list(nparray.shape),\n\u001b[0;32m--> 442\u001b[0;31m                           _GetDenseDimensions(values)))\n\u001b[0m\u001b[1;32m    443\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    444\u001b[0m     \u001b[0;31m# python/numpy default float type is float64. We prefer float32 instead.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Argument must be a dense tensor: [[1, 2, 3, 4], [5, 6], [7], [8, 9, 10]] - got shape [4], but wanted [4, 4]."
     ]
    }
   ],
   "source": [
    "def pre_func(X, y):\n",
    "    X = tf.multiply(X, 2)\n",
    "    return X, y\n",
    "\n",
    "batch_size = 6\n",
    "# 创建 dataset\n",
    "dataset = tf.data.Dataset.from_tensor_slices((X, y))\n",
    "\n",
    "# 对数据进行操作\n",
    "dataset = dataset.map(pre_func)\n",
    "dataset = dataset.shuffle(buffer_size=20)  # 设置 buffle_size 越大，shuffle 越均匀\n",
    "dataset = dataset.repeat().batch(4)\n",
    "print(dataset)\n",
    "\n",
    "# 生成迭代器\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "print(iterator)\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# 迭代取值\n",
    "count = 0\n",
    "for count in range(10):\n",
    "    X_batch, y_batch = sess.run(iterator.get_next())\n",
    "    print('count = {} : y = {}'.format(count, y_batch.reshape([-1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
